using System;
using System.Runtime.CompilerServices;
using static Util;
using size_t = nuint;

/// <summary>
/// Instruction-set specific building blocks (tile sizes, packing routines and
/// micro-kernels) that <c>MatrixMulCommon</c> composes into a full blocked
/// matrix multiplication.
/// </summary>
/// <remarks>
/// All sizes, strides and offsets are element counts (doubles), not bytes:
/// the driver applies them directly as <c>double*</c> pointer offsets.
/// </remarks>
unsafe interface IMatrixMulSimdSpecific : IDisposable
{
	/// <summary>
	/// Block height. The driver uses it as the largest inner-dimension depth
	/// (<c>A_cols</c>) handed to a single packing/kernel pass and as a factor
	/// when sizing both packing buffers.
	/// </summary>
	size_t BlockRows { get; }

	/// <summary>
	/// Block width. The driver never passes more than this many output
	/// columns to one <see cref="PackBlock"/> call.
	/// </summary>
	size_t BlockCols { get; }

	/// <summary>Number of output rows covered by one <see cref="MulKernelComplete"/> call.</summary>
	size_t KernelRows { get; }

	/// <summary>Number of output columns covered by one <see cref="MulKernelComplete"/> call.</summary>
	size_t KernelCols { get; }

	/// <summary>
	/// Packs a vertical panel of the left operand into <paramref name="dst"/>.
	/// The driver later reads the packed rows starting at row <c>i</c> from
	/// <c>dst + i * src_cols</c>.
	/// </summary>
	/// <param name="dst">Destination packing buffer.</param>
	/// <param name="src">First element of the source panel.</param>
	/// <param name="src_rows">Number of rows in the panel.</param>
	/// <param name="src_cols">Number of columns in the panel.</param>
	/// <param name="src_stride">Distance between consecutive source rows, in elements.</param>
	void PackVPanel(
		double* dst,
		double* src,
		size_t src_rows,
		size_t src_cols,
		size_t src_stride);

	/// <summary>
	/// Packs a block of the right operand into <paramref name="dst"/>.
	/// The driver later reads the packed columns starting at column <c>j</c>
	/// from <c>dst + j * src_rows</c>.
	/// </summary>
	/// <param name="dst">Destination packing buffer.</param>
	/// <param name="src">First element of the source block.</param>
	/// <param name="src_rows">Number of rows in the block.</param>
	/// <param name="src_cols">Number of columns in the block.</param>
	/// <param name="src_stride">Distance between consecutive source rows, in elements.</param>
	void PackBlock(
		double* dst,
		double* src,
		size_t src_rows,
		size_t src_cols,
		size_t src_stride);

	/// <summary>
	/// Micro-kernel for a full <see cref="KernelRows"/> x <see cref="KernelCols"/>
	/// tile of the output.
	/// </summary>
	/// <param name="add">
	/// Presumably selects accumulating into <paramref name="C"/> instead of
	/// overwriting it (the driver forces <c>true</c> for every inner-dimension
	/// slice after the first) - confirm against the implementations.
	/// </param>
	/// <param name="C">Top-left element of the output tile.</param>
	/// <param name="A">Packed left-operand data for the tile's rows.</param>
	/// <param name="B">Packed right-operand data for the tile's columns.</param>
	/// <param name="A_cols">Inner-dimension depth of the packed operands.</param>
	/// <param name="C_stride">Distance between consecutive output rows, in elements.</param>
	void MulKernelComplete(
		bool add,
		double* C,
		double* A,
		double* B,
		size_t A_cols,
		size_t C_stride);

	/// <summary>
	/// Micro-kernel for an edge tile. The driver calls it when fewer than
	/// <see cref="KernelRows"/> rows and/or fewer than <see cref="KernelCols"/>
	/// columns of output remain.
	/// </summary>
	/// <param name="add">
	/// Presumably selects accumulating into <paramref name="C"/> instead of
	/// overwriting it - confirm against the implementations.
	/// </param>
	/// <param name="C">Top-left element of the output tile.</param>
	/// <param name="A">Packed left-operand data for the tile's rows.</param>
	/// <param name="B">Packed right-operand data for the tile's columns.</param>
	/// <param name="C_rows">Number of output rows in this tile.</param>
	/// <param name="C_cols">Number of output columns in this tile.</param>
	/// <param name="A_cols">Inner-dimension depth of the packed operands.</param>
	/// <param name="C_stride">Distance between consecutive output rows, in elements.</param>
	void MulKernelIncomplete(
		bool add,
		double* C,
		double* A,
		double* B,
		size_t C_rows,
		size_t C_cols,
		size_t A_cols,
		size_t C_stride);
}

/// <summary>
/// Blocked matrix multiplication driver. It slices the operands into panels
/// and blocks, has <typeparamref name="TSimd"/> pack them into scratch buffers
/// and then runs the <typeparamref name="TSimd"/> micro-kernels tile by tile.
/// </summary>
/// <remarks>
/// The instance owns unmanaged scratch memory that is grown on demand and
/// released by <see cref="Dispose()"/> (or, as a fallback, the finalizer).
/// All sizes and strides are element counts (doubles), not bytes.
/// </remarks>
/// <typeparam name="TSimd">Provider of tile sizes, packing routines and micro-kernels.</typeparam>
unsafe class MatrixMulCommon<TSimd> : IMatrixMultiplier where TSimd : IMatrixMulSimdSpecific, new()
{
	TSimd simd = new TSimd();

	/// <summary>
	/// When set, scratch memory is first requested via <c>TryAllocLargePages</c>;
	/// the cache-aligned allocator is used if that returns null.
	/// </summary>
	public bool TryLargePage { get; set; }

	private bool disposedValue;

	// One allocation holds both scratch areas: B_packed is its start and
	// A_packed points B_size elements into it (see PrepareTempArrays).
	double* B_packed, A_packed;
	// Capacity of the current allocation in elements; 0 when nothing is allocated.
	size_t totalSize;
	// Records which allocator produced B_packed so the matching free is used.
	bool isLargePage;

	/// <summary>
	/// Multiplies A (treated as C_rows x A_cols) by B (A_cols x C_cols) into C
	/// (C_rows x C_cols). Every matrix is addressed row by row using its stride.
	/// </summary>
	/// <param name="add">
	/// Forwarded to the kernels for the first inner-dimension slice; all later
	/// slices are issued with <c>true</c>. Presumably <c>true</c> means
	/// "accumulate into C" - confirm against the kernel implementations.
	/// </param>
	/// <param name="C">Output matrix.</param>
	/// <param name="A">Left operand.</param>
	/// <param name="B">Right operand.</param>
	/// <param name="C_rows">Rows of C and A.</param>
	/// <param name="C_cols">Columns of C and B.</param>
	/// <param name="A_cols">Columns of A and rows of B (inner dimension).</param>
	/// <param name="C_stride">Elements between consecutive rows of C.</param>
	/// <param name="A_stride">Elements between consecutive rows of A.</param>
	/// <param name="B_stride">Elements between consecutive rows of B.</param>
	/// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
	public void Multiply(bool add, double* C, double* A, double* B, size_t C_rows, size_t C_cols, size_t A_cols, size_t C_stride, size_t A_stride, size_t B_stride)
	{
		if (disposedValue)
			throw new ObjectDisposedException(nameof(MatrixMulCommon<TSimd>));

		PrepareTempArrays(C_rows);

		// Walk the inner dimension in slices of at most BlockRows. The first
		// slice honours the caller's 'add'; every later slice adds to what the
		// previous ones produced. The loop body runs once even for A_cols == 0,
		// so the kernels still see the request.
		size_t maxDepth = simd.BlockRows;
		size_t k = 0;
		do
		{
			size_t remaining = A_cols - k;
			size_t depth = remaining < maxDepth ? remaining : maxDepth;

			MulVPanelHPanel(add || k > 0, C, A + k, B + k * B_stride,
							C_rows, C_cols, depth,
							C_stride, A_stride, B_stride);

			k += depth;
		} while (k < A_cols);
	}

	/// <summary>
	/// Packs one vertical panel of A (C_rows x A_cols) and multiplies it with
	/// the matching horizontal panel of B, one block of at most
	/// <c>simd.BlockCols</c> columns at a time.
	/// </summary>
	void MulVPanelHPanel(bool add, double* C, double* A, double* B, size_t C_rows, size_t C_cols, size_t A_cols, size_t C_stride, size_t A_stride, size_t B_stride)
	{
		simd.PackVPanel(A_packed, A, C_rows, A_cols, A_stride);

		size_t blockCols = simd.BlockCols;

		// Full-width blocks first ...
		size_t j = 0;
		for (; j + blockCols <= C_cols; j += blockCols)
		{
			MulVPanelBlock(add, C + j, B + j, C_rows, blockCols, A_cols, C_stride, B_stride);
		}

		// ... then the narrower block at the right edge, if any.
		if (j < C_cols)
		{
			MulVPanelBlock(add, C + j, B + j, C_rows, C_cols - j, A_cols, C_stride, B_stride);
		}
	}

	/// <summary>
	/// Packs one block of B (A_cols x C_cols) and runs the micro-kernels over
	/// every KernelRows x KernelCols tile of the corresponding part of C,
	/// using the incomplete kernel for tiles cut off at the bottom/right edge.
	/// Reads the A panel previously packed into <c>A_packed</c>.
	/// </summary>
	[MethodImpl(MethodImplOptions.AggressiveOptimization)]
	void MulVPanelBlock(bool add, double* C, double* B, size_t C_rows, size_t C_cols, size_t A_cols, size_t C_stride, size_t B_stride)
	{
		simd.PackBlock(B_packed, B, A_cols, C_cols, B_stride);

		size_t K_rows = simd.KernelRows;
		size_t K_cols = simd.KernelCols;

		size_t C_cols_rem = C_cols % K_cols;
		size_t C_rows_rem = C_rows % K_rows;

		// Row bands of full kernel height.
		size_t i = 0;
		for (; i + K_rows <= C_rows; i += K_rows)
		{
			double* C_band = C + i * C_stride;
			double* A_band = A_packed + i * A_cols;

			size_t j = 0;
			for (; j + K_cols <= C_cols; j += K_cols)
			{
				simd.MulKernelComplete(add, C_band + j,
									   A_band,
									   B_packed + j * A_cols,
									   A_cols, C_stride);
			}
			if (C_cols_rem > 0)
			{
				// Right edge: full height, partial width.
				simd.MulKernelIncomplete(add, C_band + j,
										 A_band,
										 B_packed + j * A_cols,
										 K_rows, C_cols_rem,
										 A_cols, C_stride);
			}
		}

		// Bottom edge: partial height for every column tile.
		if (C_rows_rem > 0)
		{
			double* C_band = C + i * C_stride;
			double* A_band = A_packed + i * A_cols;

			size_t j = 0;
			for (; j + K_cols <= C_cols; j += K_cols)
			{
				simd.MulKernelIncomplete(add, C_band + j,
										 A_band,
										 B_packed + j * A_cols,
										 C_rows_rem, K_cols,
										 A_cols, C_stride);
			}
			if (C_cols_rem > 0)
			{
				// Bottom-right corner: partial in both directions.
				simd.MulKernelIncomplete(add, C_band + j,
										 A_band,
										 B_packed + j * A_cols,
										 C_rows_rem, C_cols_rem,
										 A_cols, C_stride);
			}
		}
	}

	/// <summary>
	/// Ensures the scratch allocation can hold one packed B block
	/// (BlockRows x BlockCols, rounded up via <c>RoundUpPow2</c>) followed by
	/// one packed A panel of <paramref name="A_rows"/> x BlockRows elements.
	/// An existing allocation that is already large enough is reused.
	/// </summary>
	/// <exception cref="OverflowException">The required element count does not fit in <c>size_t</c>.</exception>
	/// <exception cref="OutOfMemoryException">The allocator returned null.</exception>
	void PrepareTempArrays(size_t A_rows)
	{
		size_t B_size = simd.BlockRows * simd.BlockCols;
		B_size = RoundUpPow2(B_size, CachelineBytes / sizeof(double));

		// checked: a silent wrap-around here would under-allocate the buffer
		// and let the packing routines write past its end.
		size_t A_size = checked(A_rows * simd.BlockRows);
		size_t required = checked(B_size + A_size);
		if (totalSize >= required)
			return;

		// Release the old buffer first. FreeTempArrays also clears the
		// bookkeeping, so a failed allocation below cannot leave a dangling
		// pointer for Dispose/the finalizer to free a second time.
		if (B_packed != null)
		{
			FreeTempArrays();
		}

		if (TryLargePage)
		{
			size_t largeSize = RoundUpPow2(required, LargePageBytes / sizeof(double));
			double* large = TryAllocLargePages(largeSize);
			if (large != null)
			{
				B_packed = large;
				A_packed = large + B_size;
				totalSize = largeSize; // FreeLargePages is later called with this size
				isLargePage = true;
				return;
			}
		}

		double* buffer = CacheAlignAlloc(required);
		// NOTE(review): CacheAlignAlloc may already throw on failure; this guard
		// only matters if it reports failure by returning null.
		if (buffer == null)
			throw new OutOfMemoryException();

		B_packed = buffer;
		A_packed = buffer + B_size;
		totalSize = required;
		isLargePage = false;
	}

	/// <summary>
	/// Releases the current scratch allocation with the allocator that created
	/// it and resets all bookkeeping to the "nothing allocated" state.
	/// Callers must ensure <c>B_packed</c> is non-null.
	/// </summary>
	void FreeTempArrays()
	{
		double* buffer = B_packed;
		size_t size = totalSize;
		bool wasLargePage = isLargePage;

		// Clear the fields before freeing so the pointer can never be
		// released twice, even if the free routine throws.
		B_packed = null;
		A_packed = null;
		totalSize = 0;
		isLargePage = false;

		if (wasLargePage)
			FreeLargePages(buffer, size);
		else
			CacheAlignFree(buffer);
	}

	/// <summary>
	/// Releases the scratch memory and, when called from <see cref="Dispose()"/>,
	/// also disposes the SIMD helper.
	/// </summary>
	/// <param name="disposing">
	/// <c>true</c> when called from <see cref="Dispose()"/>; <c>false</c> when
	/// called from the finalizer, in which case <c>simd.Dispose()</c> is skipped.
	/// </param>
	protected virtual void Dispose(bool disposing)
	{
		if (disposedValue)
			return;

		if (disposing)
		{
			simd.Dispose();
		}

		// The scratch buffer is unmanaged, so it is released on both paths.
		if (B_packed != null)
			FreeTempArrays();

		simd = default;
		disposedValue = true;
	}

	~MatrixMulCommon()
	{
		Dispose(disposing: false);
	}

	/// <summary>Releases all resources held by this instance.</summary>
	public void Dispose()
	{
		Dispose(disposing: true);
		GC.SuppressFinalize(this);
	}
}
